import pandas as pd
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.optimize import curve_fit
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import warnings
import requests # Import the requests library
warnings.filterwarnings('ignore')

# ============================================================================
# 1. LOAD EXTENDED INTERNATIONAL DATABASE
# ============================================================================
class IAEAExtendedDatabase:
    """Load and prepare the extended international charge-radius database.

    Radii from the IAEA charge-radii table are merged with a small curated set
    of nuclei that have both an electronic radius (``R_e``) and a muonic
    radius (``R_mu``). Radii are in fm.
    """

    # Default location of the IAEA charge-radii CSV table.
    IAEA_URL = "https://nds.iaea.org/radii/charge_radii.csv"

    # Browser-like User-Agent to avoid the 403 Forbidden error the server
    # returns for plain scripted requests.
    HTTP_HEADERS = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}

    # Original data with muonic measurements (from paper).
    ORIGINAL_MUONIC_DATA = {
        'p': {'A': 1, 'Z': 1, 'R_e': 0.8751, 'u_R_e': 0.0061,
              'R_mu': 0.8409, 'u_R_mu': 0.0004, 'source': 'CODATA/Pohl2010'},
        'd': {'A': 2, 'Z': 1, 'R_e': 2.1250, 'u_R_e': 0.0030,
              'R_mu': 2.12758, 'u_R_mu': 0.00078, 'source': 'Krauth2021'},
        'He-3': {'A': 3, 'Z': 2, 'R_e': 1.9660, 'u_R_e': 0.0150,
                 'R_mu': 1.97007, 'u_R_mu': 0.00094, 'source': 'Schuhmann2025'},
        'He-4': {'A': 4, 'Z': 2, 'R_e': 1.6810, 'u_R_e': 0.0040,
                 'R_mu': 1.67824, 'u_R_mu': 0.00083, 'source': 'Schuhmann2025'},
        'C-13': {'A': 13, 'Z': 6, 'R_e': 2.4927, 'u_R_e': 0.0012,
                 'R_mu': 2.4829, 'u_R_mu': 0.0019, 'source': 'Abe2025'}
    }

    @staticmethod
    def load_extended_database(url=None, timeout=30):
        """Download the IAEA table and merge in the curated muonic data.

        Parameters
        ----------
        url : str, optional
            CSV location; defaults to ``IAEAExtendedDatabase.IAEA_URL``.
        timeout : float, optional
            Seconds to wait for the server before giving up (default 30).

        Returns
        -------
        pandas.DataFrame
            One row per nucleus with columns ``nucleus, A, Z, N, R_e, u_R_e,
            R_mu, u_R_mu, has_muonic, source`` (see ``build_from_iaea_table``).

        Raises
        ------
        requests.HTTPError
            If the server answers with an HTTP error status.
        requests.RequestException
            On connection failures or when ``timeout`` expires.
        """
        source_url = url if url is not None else IAEAExtendedDatabase.IAEA_URL
        df_iaea = IAEAExtendedDatabase._fetch_iaea_table(source_url, timeout)
        df = IAEAExtendedDatabase.build_from_iaea_table(df_iaea)

        n_muonic = df['has_muonic'].sum()
        print(f"✅ Database loaded: {len(df)} nuclei")
        print(f"   - With muonic measurements: {n_muonic}")
        print(f"   - Without muonic measurements: {len(df) - n_muonic}")

        return df

    @staticmethod
    def _fetch_iaea_table(url, timeout):
        """Fetch the CSV at *url* and parse it into a raw DataFrame."""
        import io

        # ``response.text`` (unlike ``response.raw``) has any gzip/deflate
        # Content-Encoding already undone, so pandas always sees plain CSV;
        # the ``with`` block releases the connection afterwards.
        with requests.get(url, headers=IAEAExtendedDatabase.HTTP_HEADERS,
                          timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for HTTP errors
            return pd.read_csv(io.StringIO(response.text))

    @staticmethod
    def build_from_iaea_table(df_iaea):
        """Build the merged nucleus table from a raw IAEA DataFrame.

        Reads the IAEA columns ``a``, ``z``, ``n``, ``symbol``,
        ``radius_val``, ``radius_unc``, ``radius_preliminary_val`` and
        ``radius_preliminary_unc``. The preliminary value/uncertainty is used
        when present. Rows without any radius are dropped, a missing
        uncertainty defaults to 0.001, and any IAEA row whose (A, Z) matches
        a curated muonic nucleus is replaced by the curated entry. Curated
        entries are appended after the IAEA rows. *df_iaea* is not modified.
        """
        df_iaea = df_iaea.copy()

        # Use preliminary value if available
        df_iaea['R_e_final'] = df_iaea['radius_preliminary_val'].combine_first(df_iaea['radius_val'])
        df_iaea['u_R_e_final'] = df_iaea['radius_preliminary_unc'].combine_first(df_iaea['radius_unc'])

        # Add nucleus identifier, e.g. "Ca-40"
        df_iaea['nucleus'] = df_iaea['symbol'] + '-' + df_iaea['a'].astype(str)

        curated = IAEAExtendedDatabase.ORIGINAL_MUONIC_DATA
        # (A, Z) pairs superseded by curated entries, built once instead of
        # re-filtering the whole list for each curated nucleus.
        curated_keys = {(d['A'], d['Z']) for d in curated.values()}

        data_list = []
        for _, row in df_iaea.iterrows():
            if pd.isna(row['R_e_final']):
                continue
            A, Z = int(row['a']), int(row['z'])
            if (A, Z) in curated_keys:
                continue
            data_list.append({
                'nucleus': row['nucleus'],
                'A': A,
                'Z': Z,
                'N': int(row['n']),
                'R_e': float(row['R_e_final']),
                'u_R_e': float(row['u_R_e_final']) if pd.notna(row['u_R_e_final']) else 0.001,
                'R_mu': np.nan,
                'u_R_mu': np.nan,
                'has_muonic': False,
                'source': 'IAEA Database'
            })

        for name, data in curated.items():
            data_list.append({
                'nucleus': name,
                'A': data['A'],
                'Z': data['Z'],
                'N': data['A'] - data['Z'],
                'R_e': data['R_e'],
                'u_R_e': data['u_R_e'],
                'R_mu': data['R_mu'],
                'u_R_mu': data['u_R_mu'],
                'has_muonic': True,
                'source': data['source']
            })

        return pd.DataFrame(data_list)

# ============================================================================
# 2. APPLY SCALING LAW WITH ENHANCEMENTS
# ============================================================================
class EnhancedScalingLawAnalyzer:
    """Predict muonic charge radii from electronic ones via a lepton-mass law.

    Model:  R_mu = R_e - k_p * S(A, Z) * (1/m_e - 1/m_mu)

    S is a nucleus-dependent scaling factor chosen by ``scaling_type`` and
    k_p is calibrated on the proton, for which every supported S equals 1.
    """

    def __init__(self):
        # Lepton masses in MeV.
        self.constants = dict(
            m_e=0.5109989461,
            m_mu=105.6583745,
            m_tau=1776.86,
        )
        # Set by calibrate_from_proton(): {'value': ..., 'uncertainty': ...}.
        self.k_p = None
        # Last table produced by apply_scaling_law().
        self.results_df = None

    def _inverse_mass_difference(self):
        """Return 1/m_e - 1/m_mu in MeV⁻¹ (the model's mass factor)."""
        lepton = self.constants
        return 1/lepton['m_e'] - 1/lepton['m_mu']

    def calibrate_from_proton(self):
        """Calibrate k_p from the proton's electronic and muonic radii.

        Uses the proton values quoted in the paper. The result is stored in
        ``self.k_p`` and returned as ``{'value': k_p, 'uncertainty': u_k_p}``
        in fm·MeV. The two radius uncertainties are added in quadrature.
        """
        # Proton radii (fm) with their 1σ uncertainties.
        electronic, sigma_electronic = 0.8751, 0.0061
        muonic, sigma_muonic = 0.8409, 0.0004

        factor = self._inverse_mass_difference()
        value = (electronic - muonic) / factor
        sigma = np.sqrt(sigma_electronic**2 + sigma_muonic**2) / abs(factor)

        self.k_p = {'value': value, 'uncertainty': sigma}

        print(f"\n🔧 Calibrated k_p from proton:")
        print(f"   k_p = {value:.6f} ± {sigma:.6f} fm·MeV")
        print(f"   mass_factor = {factor:.6f} MeV⁻¹")

        return self.k_p

    def apply_scaling_law(self, df, scaling_type='A13'):
        """
        Apply the scaling law to every nucleus in ``df``.

        scaling_type picks S(A, Z): 'A13' -> A^(1/3), 'Z' -> Z,
        'Z23' -> Z^(2/3), 'mixed' -> 0.7·A^(1/3) + 0.3·Z. Any other value
        (e.g. 'optimal') falls back to A^(1/3).

        ``df`` must provide nucleus, A, Z, N, R_e, u_R_e, R_mu, u_R_mu,
        has_muonic and source. k_p is calibrated first if necessary. The
        result table is stored in ``self.results_df`` and returned.
        """
        if self.k_p is None:
            self.calibrate_from_proton()

        k_p = self.k_p['value']
        u_k_p = self.k_p['uncertainty']
        mass_factor = self._inverse_mass_difference()

        scale_rules = {
            'A13': lambda A, Z: A**(1/3),
            'Z': lambda A, Z: Z,
            'Z23': lambda A, Z: Z**(2/3),
            'mixed': lambda A, Z: 0.7 * A**(1/3) + 0.3 * Z,
        }
        scale = scale_rules.get(scaling_type, scale_rules['A13'])

        rows_out = []
        for _, nucleus_row in df.iterrows():
            A = nucleus_row['A']
            Z = nucleus_row['Z']
            R_e = nucleus_row['R_e']
            u_R_e = nucleus_row['u_R_e']
            measured = nucleus_row['has_muonic']
            R_mu_exp = nucleus_row['R_mu'] if measured else np.nan

            factor = scale(A, Z)
            shift = k_p * factor * mass_factor          # predicted R_e - R_mu
            model_unc = factor * mass_factor * u_k_p    # k_p's share of the error

            R_mu_pred = R_e - shift
            u_R_mu_pred = np.sqrt(u_R_e**2 + model_unc**2)

            if measured and not np.isnan(R_mu_exp):
                # Electronic radius the law would need to reproduce the measured R_mu.
                R_e_req = R_mu_exp + shift
                delta_R_e = R_e_req - R_e
                u_delta_R_e = np.sqrt(nucleus_row['u_R_mu']**2 + u_R_e**2 + model_unc**2)
                significance = abs(delta_R_e) / u_delta_R_e if u_delta_R_e > 0 else 0
            else:
                R_e_req = delta_R_e = u_delta_R_e = significance = np.nan

            rows_out.append({
                'nucleus': nucleus_row['nucleus'],
                'A': A, 'Z': Z, 'N': nucleus_row['N'],
                'R_e': R_e, 'u_R_e': u_R_e,
                'R_mu_pred': R_mu_pred, 'u_R_mu_pred': u_R_mu_pred,
                'R_mu_exp': R_mu_exp,
                'u_R_mu_exp': nucleus_row['u_R_mu'] if measured else np.nan,
                'has_muonic': measured,
                'R_e_req': R_e_req,
                'delta_R_e': delta_R_e,
                'u_delta_R_e': u_delta_R_e,
                'pred_delta': R_mu_pred - R_e,  # predicted R_mu - R_e, defined for all nuclei
                'significance': significance,
                'scaling_factor': factor,
                'scaling_type': scaling_type,
                'source': nucleus_row['source'],
            })

        self.results_df = pd.DataFrame(rows_out)
        return self.results_df

# ============================================================================
# 3. ADVANCED PATTERN DETECTION AND ANALYSIS
# ============================================================================
class PatternDetector:
    """Advanced pattern detection in scaling law results.

    Operates on a private copy of a results table such as the one returned
    by ``EnhancedScalingLawAnalyzer.apply_scaling_law``. The columns read
    are ``nucleus``, ``A``, ``Z``, ``N``, ``R_e``, ``R_mu_pred`` and
    ``pred_delta``. Each ``detect_*`` method prints a short report and
    returns a dict with the numbers it found.
    """

    def __init__(self, results_df):
        self.df = results_df.copy()
        # Ensure numeric types
        for col in ['pred_delta', 'R_e', 'R_mu_pred', 'A', 'Z']:
            self.df[col] = pd.to_numeric(self.df[col], errors='coerce')
        # N feeds the N/Z and magic-number analyses; coerce it too when present.
        if 'N' in self.df.columns:
            self.df['N'] = pd.to_numeric(self.df['N'], errors='coerce')

    def detect_mass_number_patterns(self):
        """Analyze patterns with respect to mass number A.

        Returns a dict with ``'linear_A'`` (only when more than two points
        exist), ``'power_law_A'`` (``None`` if the fit fails) and
        ``'mass_bins'`` (a DataFrame of per-bin ΔR statistics). Adds a
        ``mass_bin`` column to ``self.df``.
        """
        print("\n🔍 Analyzing Mass Number (A) Patterns...")

        patterns = {}

        # 1. Linear regression: pred_delta vs A
        valid_data = self.df.dropna(subset=['pred_delta', 'A'])
        A_vals = valid_data['A'].values
        delta_vals = valid_data['pred_delta'].values

        if len(A_vals) > 2:
            slope, intercept, r_value, p_value, std_err = stats.linregress(A_vals, delta_vals)
            patterns['linear_A'] = {
                'slope': slope, 'intercept': intercept,
                'r_squared': r_value**2, 'p_value': p_value,
                'std_err': std_err
            }
            print(f"   Linear fit: ΔR = {slope:.6f}·A + {intercept:.6f}")
            print(f"   R² = {r_value**2:.4f}, p = {p_value:.3e}")

        # 2. Power law fitting: ΔR ∝ A^α
        def power_law(x, a, b):
            return a * x**b

        try:
            popt, pcov = curve_fit(power_law, A_vals, delta_vals,
                                   p0=[-0.01, 0.33], maxfev=5000)
            patterns['power_law_A'] = {
                'coefficient': popt[0], 'exponent': popt[1],
                'covariance': pcov.tolist()
            }
            print(f"   Power law: ΔR = {popt[0]:.6f}·A^{popt[1]:.4f}")
            print(f"   Exponent close to 1/3? {abs(popt[1] - 1/3):.4f}")
        except (RuntimeError, ValueError, TypeError, np.linalg.LinAlgError):
            # curve_fit raises RuntimeError when it does not converge and
            # TypeError/ValueError for too few or non-finite points.
            patterns['power_law_A'] = None

        # 3. Bin analysis by mass ranges. The last edge is open-ended so that
        # heavy nuclei (A > 200, e.g. Pb-208, U-238) land in 'A>100' instead of
        # being turned into NaN by pd.cut and silently dropped.
        mass_bins = [0, 10, 30, 50, 100, np.inf]
        mass_labels = ['A≤10', '10<A≤30', '30<A≤50', '50<A≤100', 'A>100']

        self.df['mass_bin'] = pd.cut(self.df['A'], bins=mass_bins, labels=mass_labels)
        bin_stats = self.df.groupby('mass_bin', observed=False).agg({
            'pred_delta': ['mean', 'std', 'count', 'min', 'max'],
            'A': 'mean'
        }).round(5)

        patterns['mass_bins'] = bin_stats
        print(f"\n   Mass bin statistics:")
        print(bin_stats.to_string())

        return patterns

    def detect_z_dependence(self):
        """Analyze patterns with respect to atomic number Z.

        Returns a dict with ``'linear_Z'`` (only when more than two points
        exist) and ``'element_groups'``: mean/std of ΔR for light (Z≤10),
        medium (10<Z≤30) and heavy (Z>30) elements; empty groups are omitted.
        """
        print("\n🔍 Analyzing Atomic Number (Z) Patterns...")

        patterns = {}

        valid_data = self.df.dropna(subset=['pred_delta', 'Z'])
        Z_vals = valid_data['Z'].values
        delta_vals = valid_data['pred_delta'].values

        # 1. Linear regression
        if len(Z_vals) > 2:
            slope, intercept, r_value, p_value, std_err = stats.linregress(Z_vals, delta_vals)
            patterns['linear_Z'] = {
                'slope': slope, 'intercept': intercept,
                'r_squared': r_value**2, 'p_value': p_value
            }
            print(f"   Linear fit: ΔR = {slope:.6f}·Z + {intercept:.6f}")
            print(f"   R² = {r_value**2:.4f}, p = {p_value:.3e}")

        # 2. Analyze by element groups: name -> (row mask, printable Z range).
        z = self.df['Z']
        element_groups = {
            'Light': (z <= 10, 'Z≤10'),
            'Medium': ((z > 10) & (z <= 30), '10<Z≤30'),
            'Heavy': (z > 30, 'Z>30'),
        }

        group_stats = {}
        for name, (mask, z_range) in element_groups.items():
            group = self.df[mask]
            if len(group) > 0:
                mean_delta = group['pred_delta'].mean()
                std_delta = group['pred_delta'].std()
                group_stats[name] = {
                    'mean_ΔR': mean_delta,
                    'std_ΔR': std_delta,
                    'count': len(group),
                    'mean_Z': group['Z'].mean()
                }
                print(f"   {name} elements ({z_range}):")
                print(f"     Mean ΔR = {mean_delta:.5f} fm, Std = {std_delta:.5f} fm")

        patterns['element_groups'] = group_stats

        return patterns

    def detect_nucleon_ratio_patterns(self):
        """Analyze patterns with respect to the N/Z ratio.

        Adds an ``N_over_Z`` column to ``self.df`` (infinite ratios become
        NaN). With more than two valid rows, returns Pearson/Spearman
        correlations of ΔR with N/Z under ``'correlations'`` and per-bin
        statistics under ``'nz_bins'``; otherwise returns an empty dict.
        """
        print("\n🔍 Analyzing N/Z Ratio Patterns...")

        patterns = {}

        # Calculate N/Z ratio
        self.df['N_over_Z'] = self.df['N'] / self.df['Z']
        self.df['N_over_Z'] = self.df['N_over_Z'].replace([np.inf, -np.inf], np.nan)

        # .copy(): a bin column is added below, so work on an independent
        # frame rather than a possible view of self.df.
        valid_data = self.df.dropna(subset=['pred_delta', 'N_over_Z']).copy()

        if len(valid_data) > 2:
            # 1. Scatter analysis
            nz_vals = valid_data['N_over_Z'].values
            delta_vals = valid_data['pred_delta'].values

            # 2. Correlation analysis
            pearson_r, pearson_p = stats.pearsonr(nz_vals, delta_vals)
            spearman_r, spearman_p = stats.spearmanr(nz_vals, delta_vals)

            patterns['correlations'] = {
                'pearson': {'r': pearson_r, 'p': pearson_p},
                'spearman': {'r': spearman_r, 'p': spearman_p}
            }

            print(f"   Pearson correlation: r = {pearson_r:.4f}, p = {pearson_p:.3e}")
            print(f"   Spearman correlation: r = {spearman_r:.4f}, p = {spearman_p:.3e}")

            # 3. Bin by N/Z ratio. Start at 0 inclusive so N = 0 (¹H) and
            # N/Z = 0.5 (³He) fall into 'N/Z≤1.0', and leave the top bin
            # open-ended to match its 'N/Z>2.0' label.
            nz_bins = [0.0, 1.0, 1.2, 1.4, 2.0, np.inf]
            nz_labels = ['N/Z≤1.0', '1.0<N/Z≤1.2', '1.2<N/Z≤1.4', '1.4<N/Z≤2.0', 'N/Z>2.0']

            valid_data['nz_bin'] = pd.cut(valid_data['N_over_Z'], bins=nz_bins,
                                          labels=nz_labels, include_lowest=True)
            bin_stats = valid_data.groupby('nz_bin', observed=False).agg({
                'pred_delta': ['mean', 'std', 'count'],
                'N_over_Z': 'mean'
            }).round(5)

            patterns['nz_bins'] = bin_stats
            print(f"\n   N/Z ratio statistics:")
            print(bin_stats.to_string())

        return patterns

    def detect_magic_number_effects(self):
        """Check for effects near magic numbers.

        For each magic number, collects nuclei whose Z or N lies within ±2 of
        it. Returns ``{'magic_<n>': {...}}`` with mean/std ΔR, the count and
        up to five nucleus names; magic numbers with no nearby nuclei are
        omitted.
        """
        print("\n🔍 Analyzing Magic Number Effects...")

        magic_numbers = [2, 8, 20, 28, 50, 82, 126]
        patterns = {}

        for magic in magic_numbers:
            # Nuclei with Z or N near magic numbers
            near_magic = self.df[
                ((self.df['Z'] - magic).abs() <= 2) |
                ((self.df['N'] - magic).abs() <= 2)
            ]

            if len(near_magic) > 0:
                mean_delta = near_magic['pred_delta'].mean()
                std_delta = near_magic['pred_delta'].std()

                patterns[f'magic_{magic}'] = {
                    'mean_ΔR': mean_delta,
                    'std_ΔR': std_delta,
                    'count': len(near_magic),
                    'nuclei': near_magic['nucleus'].tolist()[:5]  # First 5
                }

                print(f"   Near magic number {magic}:")
                print(f"     Mean ΔR = {mean_delta:.5f} fm, Std = {std_delta:.5f} fm")
                print(f"     Sample nuclei: {', '.join(near_magic['nucleus'].tolist()[:3])}")

        return patterns

    def detect_clusters(self, n_clusters=4):
        """Use clustering to find natural groups in the data.

        Runs K-means on standardized (A, Z, pred_delta, R_e). Returns a dict
        keyed by cluster id with size, feature means and up to three sample
        nuclei, or ``None`` when there are not more rows than clusters.
        """
        print(f"\n🔍 Detecting Natural Clusters (K-means, k={n_clusters})...")

        # Prepare features for clustering
        features = self.df[['A', 'Z', 'pred_delta', 'R_e']].dropna()

        if len(features) <= n_clusters:
            print("   Insufficient data for clustering")
            return None

        # Standardize features
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(features)

        # Apply K-means clustering (fixed seed for reproducible labels)
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        clusters = kmeans.fit_predict(X_scaled)

        # Add cluster labels to a copy of the feature table
        features_df = features.copy()
        features_df['cluster'] = clusters

        # Analyze each cluster
        cluster_stats = {}
        for cluster_id in range(n_clusters):
            cluster_data = features_df[features_df['cluster'] == cluster_id]
            cluster_stats[cluster_id] = {
                'size': len(cluster_data),
                'mean_A': cluster_data['A'].mean(),
                'mean_Z': cluster_data['Z'].mean(),
                'mean_ΔR': cluster_data['pred_delta'].mean(),
                'mean_R_e': cluster_data['R_e'].mean(),
                'sample_nuclei': self.df.loc[cluster_data.index, 'nucleus'].tolist()[:3]
            }

            print(f"   Cluster {cluster_id}:")
            print(f"     Size: {len(cluster_data)} nuclei")
            print(f"     Mean A: {cluster_data['A'].mean():.1f}, Z: {cluster_data['Z'].mean():.1f}")
            print(f"     Mean ΔR: {cluster_data['pred_delta'].mean():.5f} fm")
            print(f"     Sample: {', '.join(cluster_stats[cluster_id]['sample_nuclei'])}")

        return cluster_stats

    def comprehensive_pattern_report(self):
        """Run every analysis, print a summary and save it to CSV files.

        Returns a dict with keys ``mass_patterns``, ``z_patterns``,
        ``nz_patterns``, ``magic_patterns``, ``clusters`` and ``summary``.
        """
        print("\n" + "="*60)
        print("COMPREHENSIVE PATTERN ANALYSIS REPORT")
        print("="*60)

        report = {}

        # Run all pattern analyses
        report['mass_patterns'] = self.detect_mass_number_patterns()
        report['z_patterns'] = self.detect_z_dependence()
        report['nz_patterns'] = self.detect_nucleon_ratio_patterns()
        report['magic_patterns'] = self.detect_magic_number_effects()
        report['clusters'] = self.detect_clusters(n_clusters=4)

        # Summary statistics
        delta = self.df['pred_delta']
        summary = {
            'total_nuclei': len(self.df),
            'mean_pred_delta': delta.mean(),
            'std_pred_delta': delta.std(),
            'min_pred_delta': delta.min(),
            'max_pred_delta': delta.max(),
            'median_pred_delta': delta.median(),
            'q1_pred_delta': delta.quantile(0.25),
            'q3_pred_delta': delta.quantile(0.75)
        }

        report['summary'] = summary

        print(f"\n📊 SUMMARY STATISTICS:")
        print(f"   Total nuclei analyzed: {summary['total_nuclei']}")
        print(f"   Mean predicted ΔR: {summary['mean_pred_delta']:.5f} fm")
        print(f"   Std of predicted ΔR: {summary['std_pred_delta']:.5f} fm")
        print(f"   Range: [{summary['min_pred_delta']:.5f}, {summary['max_pred_delta']:.5f}] fm")
        print(f"   Median: {summary['median_pred_delta']:.5f} fm")
        print(f"   IQR: [{summary['q1_pred_delta']:.5f}, {summary['q3_pred_delta']:.5f}] fm")

        # Save detailed report
        self.save_pattern_report(report)

        return report

    def save_pattern_report(self, report):
        """Save pattern analysis to files.

        Writes ``pattern_analysis_details.csv`` (one row per dict-valued
        pattern entry, details truncated to 500 characters) and
        ``pattern_analysis_summary.csv`` (the ``report['summary']`` dict) to
        the current working directory.
        """
        # Flatten dict-valued entries; DataFrames (e.g. bin tables) and None are skipped.
        pattern_data = []

        for pattern_type, patterns in report.items():
            if pattern_type == 'summary' or not isinstance(patterns, dict):
                continue
            for key, value in patterns.items():
                if isinstance(value, dict):
                    pattern_data.append({
                        'pattern_type': pattern_type,
                        'pattern_key': key,
                        'details': str(value)[:500]  # Truncate long strings
                    })

        # Explicit columns keep a header row even when nothing was collected.
        pattern_df = pd.DataFrame(pattern_data, columns=['pattern_type', 'pattern_key', 'details'])
        pattern_df.to_csv('pattern_analysis_details.csv', index=False, encoding='utf-8')

        # Save summary
        summary_df = pd.DataFrame([report['summary']])
        summary_df.to_csv('pattern_analysis_summary.csv', index=False, encoding='utf-8')

        print(f"\n💾 Pattern analysis saved to:")
        print(f"   - pattern_analysis_details.csv")
        print(f"   - pattern_analysis_summary.csv")

# ============================================================================
# 4. VISUALIZATION WITH PATTERN HIGHLIGHTS
# ============================================================================
class PatternVisualization:
    """Create visualizations highlighting detected patterns"""

    @staticmethod
    def create_pattern_visualizations(results_df, pattern_report, save_dir='./plots/'):
        """Create a 3×3 grid of pattern plots, save it as a PNG and show it.

        Parameters
        ----------
        results_df : pandas.DataFrame
            Scaling-law results; the columns ``A``, ``Z``, ``N``, ``R_e``,
            ``pred_delta`` and ``scaling_type`` are read. The caller's frame
            is not modified.
        pattern_report : dict
            Not used by the current plots; kept for API compatibility.
        save_dir : str, default './plots/'
            Output directory, created if missing. The figure is written to
            ``pattern_analysis_comprehensive.png`` inside it.
        """
        import os

        # Work on a private copy: callers typically pass a filtered slice, and
        # the helper columns added below must not leak back into it.
        results_df = results_df.copy()

        # Ensure numeric types
        for col in ('pred_delta', 'A', 'Z'):
            results_df[col] = pd.to_numeric(results_df[col], errors='coerce')

        pred_delta_clean = results_df['pred_delta'].dropna()
        if pred_delta_clean.empty:
            # Histogram statistics and percentiles are undefined without data.
            print("⚠️ No predicted ΔR values to plot - skipping visualizations.")
            return

        # Create figure with multiple subplots
        fig = plt.figure(figsize=(20, 16))

        # 1. Predicted ΔR vs A with mass bin coloring
        ax1 = plt.subplot(3, 3, 1)

        # Mass bins for coloring; the open last edge keeps every A > 200.
        mass_bins = [0, 10, 30, 50, 100, 200, np.inf]
        mass_labels = ['A≤10', '10<A≤30', '30<A≤50', '50<A≤100', '100<A≤200', 'A>200']
        colors = plt.cm.viridis(np.linspace(0, 1, len(mass_labels)))

        results_df['mass_bin_idx'] = pd.cut(results_df['A'], bins=mass_bins, labels=False)

        for i, label in enumerate(mass_labels):
            bin_data = results_df[results_df['mass_bin_idx'] == i]
            if len(bin_data) > 0:
                ax1.scatter(bin_data['A'], bin_data['pred_delta'],
                            alpha=0.6, s=20, color=colors[i], label=label)

        ax1.set_xlabel('Mass Number A')
        ax1.set_ylabel('Predicted ΔR (R_μ_pred - R_e) [fm]')
        ax1.set_title('Predicted ΔR vs Mass Number (Colored by Mass Bin)')
        ax1.legend(loc='upper left', fontsize=8)
        ax1.grid(True, alpha=0.3)

        # 2. Predicted ΔR vs Z with element group coloring
        ax2 = plt.subplot(3, 3, 2)

        # Element groups; the open last edge keeps Z > 100 in 'Heavy'.
        results_df['element_group'] = pd.cut(results_df['Z'],
                                             bins=[0, 10, 30, np.inf],
                                             labels=['Light', 'Medium', 'Heavy'])
        group_colors = {'Light': 'blue', 'Medium': 'green', 'Heavy': 'red'}

        for group, color in group_colors.items():
            group_data = results_df[results_df['element_group'] == group]
            if len(group_data) > 0:
                ax2.scatter(group_data['Z'], group_data['pred_delta'],
                            alpha=0.6, s=20, color=color, label=group)

        ax2.set_xlabel('Atomic Number Z')
        ax2.set_ylabel('Predicted ΔR [fm]')
        ax2.set_title('Predicted ΔR vs Atomic Number (Colored by Element Group)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # 3. Histogram of predicted ΔR with statistical annotations
        ax3 = plt.subplot(3, 3, 3)

        ax3.hist(pred_delta_clean, bins=50, alpha=0.7, color='purple', edgecolor='black')

        # Add statistical lines
        mean_val = pred_delta_clean.mean()
        median_val = pred_delta_clean.median()
        std_val = pred_delta_clean.std()

        ax3.axvline(mean_val, color='red', linestyle='--', linewidth=2,
                    label=f'Mean: {mean_val:.4f} fm')
        ax3.axvline(median_val, color='green', linestyle='--', linewidth=2,
                    label=f'Median: {median_val:.4f} fm')
        ax3.axvline(mean_val - std_val, color='orange', linestyle=':', alpha=0.7)
        ax3.axvline(mean_val + std_val, color='orange', linestyle=':', alpha=0.7)

        ax3.set_xlabel('Predicted ΔR [fm]')
        ax3.set_ylabel('Frequency')
        ax3.set_title('Distribution of Predicted ΔR with Statistics')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # 4. Relative ΔR vs A (ΔR/R_e)
        ax4 = plt.subplot(3, 3, 4)

        results_df['relative_delta'] = results_df['pred_delta'] / results_df['R_e'] * 100
        scatter = ax4.scatter(results_df['A'], results_df['relative_delta'],
                              c=results_df['Z'], alpha=0.6, cmap='plasma', s=20)

        ax4.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax4.set_xlabel('Mass Number A')
        ax4.set_ylabel('Relative ΔR (%) = (R_μ_pred - R_e)/R_e × 100')
        ax4.set_title('Relative Predicted Difference vs Mass Number')
        ax4.grid(True, alpha=0.3)
        plt.colorbar(scatter, ax=ax4, label='Atomic Number Z')

        # 5. 2D histogram: A vs Z colored by ΔR
        ax5 = plt.subplot(3, 3, 5)

        # Create hexbin plot
        hb = ax5.hexbin(results_df['A'], results_df['Z'], C=results_df['pred_delta'],
                        gridsize=30, cmap='coolwarm', reduce_C_function=np.mean)

        ax5.set_xlabel('Mass Number A')
        ax5.set_ylabel('Atomic Number Z')
        ax5.set_title('2D Distribution: Mean ΔR in (A,Z) Space')
        cb = plt.colorbar(hb, ax=ax5)
        cb.set_label('Mean Predicted ΔR [fm]')

        # 6. N/Z ratio analysis
        ax6 = plt.subplot(3, 3, 6)

        # Calculate N/Z ratio (Z = 0 would give inf; treat it as missing)
        results_df['N_over_Z'] = results_df['N'] / results_df['Z']
        results_df['N_over_Z'] = results_df['N_over_Z'].replace([np.inf, -np.inf], np.nan)

        valid_nz = results_df.dropna(subset=['N_over_Z', 'pred_delta'])

        scatter = ax6.scatter(valid_nz['N_over_Z'], valid_nz['pred_delta'],
                              c=valid_nz['A'], alpha=0.6, cmap='viridis', s=20)

        ax6.set_xlabel('Neutron-to-Proton Ratio (N/Z)')
        ax6.set_ylabel('Predicted ΔR [fm]')
        ax6.set_title('Predicted ΔR vs N/Z Ratio')
        ax6.grid(True, alpha=0.3)
        plt.colorbar(scatter, ax=ax6, label='Mass Number A')

        # 7. Comparison of scaling laws
        ax7 = plt.subplot(3, 3, 7)

        scaling_types = results_df['scaling_type'].unique()
        colors = plt.cm.Set1(np.linspace(0, 1, len(scaling_types)))

        for i, sc_type in enumerate(scaling_types):
            subset = results_df[results_df['scaling_type'] == sc_type]
            if len(subset) > 0:
                ax7.scatter(subset['A'], subset['pred_delta'],
                            alpha=0.5, s=15, color=colors[i], label=f'{sc_type}')

        ax7.set_xlabel('Mass Number A')
        ax7.set_ylabel('Predicted ΔR [fm]')
        ax7.set_title('Comparison of Different Scaling Laws')
        ax7.legend()
        ax7.grid(True, alpha=0.3)

        # 8. Magic number effects (Z or N within ±2 of a magic number)
        ax8 = plt.subplot(3, 3, 8)

        magic_numbers = [2, 8, 20, 28, 50, 82, 126]
        magic_colors = plt.cm.tab10(np.linspace(0, 1, len(magic_numbers)))

        for i, magic in enumerate(magic_numbers):
            near_magic = results_df[
                ((results_df['Z'] - magic).abs() <= 2) |
                ((results_df['N'] - magic).abs() <= 2)
            ]
            if len(near_magic) > 0:
                ax8.scatter(near_magic['A'], near_magic['pred_delta'],
                            alpha=0.7, s=30, color=magic_colors[i],
                            label=f'Near {magic}', edgecolors='black')

        ax8.set_xlabel('Mass Number A')
        ax8.set_ylabel('Predicted ΔR [fm]')
        ax8.set_title('Predicted ΔR Near Magic Numbers')
        ax8.legend(loc='upper right', fontsize=8)
        ax8.grid(True, alpha=0.3)

        # 9. Cumulative distribution of predicted ΔR
        ax9 = plt.subplot(3, 3, 9)

        sorted_delta = np.sort(pred_delta_clean)
        cdf = np.arange(1, len(sorted_delta) + 1) / len(sorted_delta)

        ax9.plot(sorted_delta, cdf, 'b-', linewidth=2)
        ax9.set_xlabel('Predicted ΔR [fm]')
        ax9.set_ylabel('Cumulative Probability')
        ax9.set_title('Cumulative Distribution of Predicted ΔR')
        ax9.grid(True, alpha=0.3)

        # Add percentiles
        percentiles = [10, 25, 50, 75, 90]
        for p in percentiles:
            value = np.percentile(sorted_delta, p)
            ax9.axvline(value, color='r', linestyle='--', alpha=0.5)
            ax9.text(value, 0.5, f'{p}%', rotation=90,
                     verticalalignment='center', fontsize=8)

        plt.tight_layout()

        # os.path.join works whether or not save_dir ends with a separator.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        out_path = os.path.join(save_dir, 'pattern_analysis_comprehensive.png')
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.show()
        # Close explicitly so repeated calls do not accumulate open figures.
        plt.close(fig)

        print(f"✅ Pattern visualizations saved to {out_path}")

# ============================================================================
# 5. MAIN EXECUTION FUNCTION
# ============================================================================
def run_comprehensive_analysis():
    """Run the complete comprehensive analysis"""

    print("🚀 Starting Comprehensive Nuclear Scaling Law Analysis")
    print("=" * 60)

    try:
        # 1. Load data
        print("\n📊 Step 1: Loading international database...")
        database = IAEAExtendedDatabase()
        df = database.load_extended_database()

        # 2. Apply scaling laws
        print("\n🔧 Step 2: Applying scaling laws...")
        analyzer = EnhancedScalingLawAnalyzer()
        k_p = analyzer.calibrate_from_proton()

        # Apply different scaling types
        scaling_types = ['A13', 'Z', 'Z23', 'mixed']
        all_results = []

        for sc_type in scaling_types:
            print(f"   Applying scaling law type: {sc_type}")
            results = analyzer.apply_scaling_law(df, scaling_type=sc_type)
            all_results.append(results)

        # Combine results
        combined_results = pd.concat(all_results, ignore_index=True)

        # 3. Advanced pattern detection
        print("\n🔍 Step 3: Advanced pattern detection...")
        pattern_detector = PatternDetector(combined_results[combined_results['scaling_type'] == 'A13'])
        pattern_report = pattern_detector.comprehensive_pattern_report()

        # 4. Visualization
        print("\n🎨 Step 4: Creating pattern visualizations...")
        viz = PatternVisualization()
        viz.create_pattern_visualizations(
            combined_results[combined_results['scaling_type'] == 'A13'],
            pattern_report
        )

        # 5. Save results
        print("\n💾 Step 5: Saving all results...")

        # Save combined results
        combined_results.to_csv('full_scaling_law_results_all_types.csv', index=False, encoding='utf-8')

        # Save A13 results separately (main analysis)
        a13_results = combined_results[combined_results['scaling_type'] == 'A13']
        a13_results.to_csv('scaling_law_results_A13.csv', index=False, encoding='utf-8')

        # Save anomalies
        muonic_nuclei = a13_results[a13_results['has_muonic']]
        if len(muonic_nuclei) > 0:
            muonic_nuclei.to_csv('muonic_nuclei_analysis.csv', index=False)

        # Find nuclei with largest predicted differences
        top_predictions = a13_results.nlargest(20, 'pred_delta')[['nucleus', 'A', 'Z', 'R_e', 'R_mu_pred', 'pred_delta']]
        top_predictions.to_csv('top_predicted_differences.csv', index=False)

        print("\n✅ Analysis completed successfully!")
        print("=" * 60)
        print("\n📁 Generated files:")
        print("   1. full_scaling_law_results_all_types.csv - All scaling types")
        print("   2. scaling_law_results_A13.csv - Main A^(1/3) scaling results")
        print("   3. muonic_nuclei_analysis.csv - Analysis of nuclei with muonic data")
        print("   4. top_predicted_differences.csv - Nuclei with largest predicted ΔR")
        print("   5. pattern_analysis_details.csv - Detailed pattern analysis")
        print("   6. pattern_analysis_summary.csv - Summary of patterns")
        print("   7. plots/pattern_analysis_comprehensive.png - Pattern visualizations")

        # Display key findings
        print("\n📊 KEY FINDINGS:")
        print(f"   • Total nuclei analyzed: {len(combined_results)//4}")  # Divided by 4 scaling types
        print(f"   • Mean predicted ΔR (A13 scaling): {a13_results['pred_delta'].mean():.5f} fm")
        print(f"   • Standard deviation: {a13_results['pred_delta'].std():.5f} fm")
        print(f"   • Range: [{a13_results['pred_delta'].min():.5f}, {a13_results['pred_delta'].max():.5f}] fm")

        # Display top 5 predicted differences
        print(f"\n🔝 Top 5 largest predicted ΔR:")
        for i, (_, row) in enumerate(top_predictions.head().iterrows(), 1):
            print(f"   {i}. {row['nucleus']} (A={row['A']}, Z={row['Z']}): ΔR = {row['pred_delta']:.5f} fm")

        return combined_results, pattern_report

    except Exception as e:
        print(f"❌ Error during analysis: {e}")
        import traceback
        traceback.print_exc()
        return None, None

# ============================================================================
# 6. EXECUTE ANALYSIS
# ============================================================================
if __name__ == "__main__":
    import os

    # Figures are written below ./plots, so make sure the directory exists.
    plots_dir = './plots'
    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)

    # Run the whole pipeline; (None, None) signals a failure.
    results, patterns = run_comprehensive_analysis()

    if results is not None:
        separator = "=" * 60
        next_steps = (
            "1. Examine 'pattern_analysis_details.csv' for specific patterns",
            "2. Check 'top_predicted_differences.csv' for nuclei needing",
            "   experimental verification with muonic atoms",
            "3. Use the visualizations to identify systematic trends",
            "4. Consider refining the scaling law based on detected patterns",
            "5. Focus experimental efforts on nuclei with largest predicted ΔR",
        )
        print("\n" + separator)
        print("ANALYSIS COMPLETE - NEXT STEPS RECOMMENDED:")
        print(separator)
        for step in next_steps:
            print(step)